// Copyright 2019 eBay Inc.
// Primary authors: Simon Fell, Diego Ongaro,
//                  Raymond Kroeker, and Sathish Kandasamy.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package exec

import (
	"context"
	"fmt"

	"github.com/ebay/akutan/query/parser"
	"github.com/ebay/akutan/query/planner/plandef"
	"github.com/ebay/akutan/util/errors"
	"github.com/ebay/akutan/util/parallel"
	"github.com/sirupsen/logrus"
)

// newLoopJoin builds an operator that joins its two inputs by taking chunks of
// rows from the left input and binding each chunk into an execution of the
// right input. This is similar to an index nested loop join in SQL.
//
// The output depends on the join's Specificity. For required matches, the
// output consists of fact sets built from the left and right values wherever
// the join key matches. For optional matches, every left row produces a fact
// set, and when a matching right row exists it is included in the result.
//
// It panics unless exactly two inputs are supplied. It also calls
// panicOnInvalidSpecificity on op.Specificity before building the operator.
func newLoopJoin(op *plandef.LoopJoin, inputNodes []queryOperator) operator {
	if n := len(inputNodes); n != 2 {
		panic(fmt.Sprintf("loopJoin operation with unexpected inputs: %v", n))
	}
	panicOnInvalidSpecificity(op.Specificity)
	left, right := inputNodes[0], inputNodes[1]
	outCols, rowJoiner := joinedColumns(left.columns(), right.columns())
	return &loopJoin{
		def:    op,
		left:   left,
		right:  right,
		output: outCols,
		joiner: rowJoiner,
	}
}

// loopJoin is the operator built by newLoopJoin. It drives executions of the
// right input with chunks of rows produced by the left input.
type loopJoin struct {
	// def is the plan definition this operator was built from.
	def    *plandef.LoopJoin
	// left is the input whose result chunks drive executions of right.
	left   queryOperator
	// right is executed once per chunk of left rows, with that chunk bound.
	right  queryOperator
	// output is the union of left's and right's columns, left's first
	// (as computed by joinedColumns).
	output Columns
	// joiner builds an output row from a left row and a right row.
	joiner func(left, right []Value) []Value
}

// operator returns the plan definition for this join.
func (l *loopJoin) operator() plandef.Operator {
	return l.def
}

// columns returns the join's output columns.
func (l *loopJoin) columns() Columns {
	return l.output
}

// execute runs the loop join, sending joined rows to res.
//
// The left input runs on the calling goroutine and writes its chunks to
// leftResCh, which a second goroutine consumes. Because queryOps support bulk
// binding, the right input is executed once per chunk of left rows rather than
// once per row. The chunk's values are bound in via rhsBinder. Each execution
// of the right input is joined against its left chunk by performJoin, which
// runs concurrently with the right input.
func (l *loopJoin) execute(ctx context.Context, binder valueBinder, res results) error {
	leftResCh := make(chan ResultChunk, 4)
	wait := parallel.GoCaptureError(func() error {
		joiner := loopJoiner{
			outputTo: res,
			joiner:   l.joiner,
		}
		var performJoin func(ctx context.Context, left *ResultChunk, rightResCh <-chan ResultChunk)
		switch l.def.Specificity {
		case parser.MatchRequired:
			performJoin = joiner.eqJoin
		case parser.MatchOptional:
			performJoin = joiner.leftJoin
		default:
			logrus.Panicf("LoopJoin with unexpected Specificity value of %v", l.def.Specificity)
		}
		for leftChunk := range leftResCh {
			// we got a chunk of rows from the left, now we will execute the
			// rightOp binding any values needed from left chunk. The
			// resulting joined results are added to the resultBuilder.
			rightResCh := make(chan ResultChunk, 4)
			rightWait := parallel.Go(func() {
				performJoin(ctx, &leftChunk, rightResCh)
			})
			// Combine the new bindings available from leftChunk with the
			// source bindings in binder that came from the parent Op. The
			// collection of bindings flows down the right side of the tree.
			rhsBinder := &binderWithParent{
				parent: binder,
				child:  &leftChunk,
			}
			err := l.right.run(ctx, rhsBinder, rightResCh)
			rightWait()
			if err != nil {
				// We're abandoning the remaining left chunks, but
				// l.left.run (below) may still be sending to leftResCh.
				// Drain it so the left side can finish rather than block
				// forever on a channel nobody reads. The range loop above
				// already relies on l.left.run closing leftResCh when done.
				for range leftResCh {
				}
				return err
			}
		}
		return nil
	})
	errLeft := l.left.run(ctx, binder, leftResCh)
	errRight := wait()
	return errors.Any(errLeft, errRight)
}

// loopJoiner is used to process data coming into the join and generate the
// relevant outputs. outputTo will be called for all results generated by
// this joiner.
type loopJoiner struct {
	// outputTo receives every joined row this joiner produces.
	outputTo results
	// joiner merges a left row and a right row into an output row; it is
	// the function built by joinedColumns for this join's inputs.
	joiner   func(left, right []Value) []Value
}

// eqJoin generates results for an equal join, i.e. Specificity=Required. Each
// right factSet is combined with the left factSet that generated it, and the
// result is sent to outputTo. Left rows with no matching right rows produce
// no output.
func (lj *loopJoiner) eqJoin(ctx context.Context, left *ResultChunk, rightResCh <-chan ResultChunk) {
	// It's important to understand that offsets from the 2 inputs to the join
	// are local to that join. The offset that should be reported in the join
	// results is based on the bulkOffsets supplied by the valueBinder when the
	// join was executed. The left side of the loop join is passed the binder
	// the join was executed with. The right side of the loop join is passed a
	// binder based on the results from the left side.
	//
	// e.g. given a query plan of
	//      LoopJoin_1 ?x
	//          Lookup_1 ?x
	//          LoopJoin_2 ?y
	//              Lookup_2 $x ?y
	//              Lookup_3 $y ?r
	//
	// and some test data consisting of
	//  ?x  ?y       ?y   ?r
	//  -------      -------
	//   A   1        2   Z
	//   A   2       11   Q
	//   B  11
	//
	// an execution sequence would look like
	//  run LoopJoin_1 ?x (defaultBinder)
	//  run Lookup_1      (defaultBinder)
	//  results Lookup_1: [?x=A offset=0, ?x=B offset=0]
	//  run LoopJoin_2 ?y (bind [0]$x=A, [1]$x=B)
	//  run Lookup_2      (bind [0]$x=A, [1]$x=B)
	//  results Lookup_2: [?y=1 offset=0, ?y=2 offset=0, ?y=11 offset=1]
	//                    (this means ?y=1 and ?y=2 were from when $x=A and ?y=11 is from when $x=B)
	//  run Lookup_3      (bind [0]$x=A,$y=1, [1]$x=A,$y=2, [2]$x=B,$y=11)
	//  results Lookup_3: [?r=Z offset=1, ?r=Q offset=2]
	//  results LoopJoin_2 ?y: [?x=A,?y=2,?r=Z offset=0, ?x=B,?y=11,?r=Q offset=1]
	//  results LoopJoin_1 ?x: [?x=A,?y=2,?r=Z offset=0, ?x=B,?y=11,?r=Q offset=0]
	//
	//
	// Looking at the innermost LoopJoin_2 ?y / Lookup_2 / Lookup_3 operators
	// in more detail.
	//
	//                                      |
	//                                      | 1.execute
	//                                      v
	//                            +--------------------+--------------------------+
	//                            | LoopJoin_2 ?y      | Results                  |
	//                            | input value binder |                          |
	//                            | offset 0 $x=A      | offset 0 ?x=A ?y=2  ?r=Z |
	//      +---------------------| offset 1 $x=B      | offset 1 ?x=B ?y=11 ?r=Q |
	//      |                     +--------------------+--------------------------+
	//      |                       ^               |           ^
	//      | 2.execute left        | 3.results     |           |
	//      v                       |               | 4.execute | 5.results
	//  +--------------------+--------------------+ | right     |
	//  | Lookup_2           | Results            | |           |
	//  | input value binder | offset 0 ?y=1      | |           |
	//  | offset 0 $x=A      | offset 0 ?y=2      | |           |
	//  | offset 1 $x=B      | offset 1 ?y=11     | |           |
	//  +--------------------+--------------------+ |           |
	//                                              v           |
	//                                  +--------------------+------------------+
	//                                  | Lookup_3           | Results          |
	//                                  | input value binder | offset 1 ?r=Z    |
	//                                  | offset 0 $x=A $y=1 | offset 2 ?r=Q    |
	//                                  | offset 1 $x=A $y=2 |                  |
	//                                  | offset 2 $x=B $y=11|                  |
	//                                  +--------------------+------------------+
	//
	// Once the results from the right side of the join are available, the join
	// operator now has to create the joined results.
	//  +---------------+---------------+-----------------+--------------------------+
	//  | Join          | Left Results  |  Right Results  | Join Results             |
	//  | Input Binder  |               |                 |                          |
	//  +---------------+---------------+-----------------+--------------------------|
	//  | offset 0 $x=A | offset 0 ?y=1 |                 |                          |
	//  |               | offset 0 ?y=2 | offset 1 ?r=Z   | offset 0 ?x=A ?y=2 ?r=Z  |
	//  | offset 1 $x=B | offset 1 ?y=11| offset 2 ?r=Q   | offset 1 ?x=B ?y=11 ?r=Q |
	//  +---------------+---------------+-----------------+--------------------------+
	//
	// The first right result has offset 1, which means it was from the right
	// value binder offset 1. So this goes with the 2nd row in the left results.
	// The left row has offset 0, so the resulting output row should have offset
	// 0.
	//
	// The 2nd right result has offset 2, which means it was from the right value
	// binder offset 2. So this goes with the last row in the left results. This
	// left row has offset 1, so the resulting output row should have offset 1.
	//
	emit := func(leftRowIndex uint32, fs FactSet, rowValues []Value) {
		// The result came from row leftRowIndex of left, so it is reported
		// at that left row's offset rather than the right side's offset.
		lj.outputTo.add(ctx, left.offsets[leftRowIndex], fs, rowValues)
	}
	lj.runEqJoin(left, rightResCh, emit)
}

// leftJoin generates results for a left join, i.e. Specificity=Optional. Each
// right factSet is combined with the left row that generated it and sent to
// outputTo, and the left rows that received at least one right row are
// recorded. Once rightResCh is exhausted, every left row that received no
// right rows is emitted on its own. In that row, the right-only columns are
// left as zero Values.
//
// The offset handling described in eqJoin applies here as well.
func (lj *loopJoiner) leftJoin(ctx context.Context, left *ResultChunk, rightResCh <-chan ResultChunk) {
	matched := make([]bool, left.len())
	lj.runEqJoin(left, rightResCh, func(leftRowIndex uint32, fs FactSet, rowValues []Value) {
		matched[leftRowIndex] = true
		lj.outputTo.add(ctx, left.offsets[leftRowIndex], fs, rowValues)
	})
	// Emit the left rows that never found a partner, in left-row order.
	for i := range matched {
		if matched[i] {
			continue
		}
		lj.outputTo.add(ctx, left.offsets[i], left.Facts[i], lj.joiner(left.Row(i), nil))
	}
}

// loopJoinerResult defines the callback function that runEqJoin will call to
// notify the caller of results. leftRowIndex indicates the index into the left
// chunk's rows that the result was generated from. Offsets reported by the
// inputs to the join need to be correctly resolved to the offset in the join
// binder; they are not the same. fs is the joined fact set for the result and
// rowValues is the joined row of values.
type loopJoinerResult func(leftRowIndex uint32, fs FactSet, rowValues []Value)

// runEqJoin consumes every chunk from rightResCh. For each right row, it joins
// that row with the left row it was produced from and passes the combined
// fact set and row values to resultFn. It is the helper shared by eqJoin and
// leftJoin, and it only returns once rightResCh has been closed and fully
// processed.
//
// A right row's offset is an index into left's rows, because the binder
// handed to the right input is built from the left chunk.
func (lj *loopJoiner) runEqJoin(left *ResultChunk, rightResCh <-chan ResultChunk, resultFn loopJoinerResult) {
	for rightChunk := range rightResCh {
		for r := range rightChunk.offsets {
			leftIdx := rightChunk.offsets[r]
			leftRow := left.Row(int(leftIdx))
			facts := joinFactSets(left.Facts[leftIdx], rightChunk.Facts[r])
			values := lj.joiner(leftRow, rightChunk.Row(r))
			resultFn(leftIdx, facts, values)
		}
	}
}

// joinedColumns returns the union of the left and right columns, plus a
// function that merges a left row and a right row into a row laid out in
// those columns.
//
// The output columns are all of left's columns in order, followed by each of
// right's columns that isn't already present. When merging, right values for
// columns that are already present are dropped, so the left value is kept.
// rightVals may be empty to build a left-join row; the right-only columns then
// hold the zero Value. The merge function panics if leftVals doesn't match the
// width of left, or if a non-empty rightVals doesn't match the width of right.
func joinedColumns(left, right Columns) (Columns, func(left, right []Value) []Value) {
	combined := append(Columns(nil), left...)
	// rightDest[i] is the position in the output row that right's i-th value
	// is written to, or -1 if that column was already present. Left values
	// always map 1:1 onto the start of the row.
	rightDest := make([]int, len(right))
	for i, col := range right {
		if _, present := combined.IndexOf(col); present {
			rightDest[i] = -1
			continue
		}
		rightDest[i] = len(combined)
		combined = append(combined, col)
	}
	leftWidth, rightWidth, outWidth := len(left), len(right), len(combined)
	merge := func(leftVals, rightVals []Value) []Value {
		if len(leftVals) != leftWidth {
			panic(fmt.Sprintf("row joiner passed unexpected sized left input row of %d, should be %d", len(leftVals), leftWidth))
		}
		if n := len(rightVals); n != 0 && n != rightWidth {
			panic(fmt.Sprintf("row joiner passed unexpected sized right input row of %d, should be %d", n, rightWidth))
		}
		row := make([]Value, outWidth)
		copy(row, leftVals)
		for i, v := range rightVals {
			if dest := rightDest[i]; dest >= 0 {
				row[dest] = v
			}
		}
		return row
	}
	return combined, merge
}
